import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import DropPath, trunc_normal_


class Lambda(nn.Module):
    """Adapter that turns a plain callable into an ``nn.Module``.

    Handy for dropping ad-hoc tensor operations into ``nn.Sequential``;
    ``forward`` returns ``func(x)`` unchanged.
    """

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        result = self.func(x)
        return result


class CfC_ODEFunc_Cell(nn.Module):
    """Per-pixel closed-form liquid cell applied to a channel-first feature map.

    Each pixel's channel vector ``x`` goes through ``param_net``
    (LayerNorm -> Linear/GELU backbone -> Linear), which predicts six
    ``dim``-wide tensors ``s, b1, r, b2, d1, d2``. The cell output is::

        g          = clamp(sigmoid(r * x + b2), min_gate, 1 - min_gate)
        input_term = (1 - g) * lecun_tanh(s * x + b1) + g * tanh(x)
        decay_i    = exp(-t * (softplus(d_i) + eps))
        h          = decay_1 * x + (1 - decay_2) * input_term

    At ``t = 0`` the cell is the identity. As ``t`` grows it relaxes towards
    ``input_term``. The result is clamped to ``[-10, 10]``.

    NOTE(review): the previous version referenced ``_make_backbone``,
    ``_process_time``, ``lecun_tanh`` and the six parameter tensors without
    defining them, so it could not be constructed. The definitions here are a
    reconstruction consistent with the ``6 * dim`` head width. Confirm the
    decay/gate parameterization against the intended CfC formulation.

    Args:
        dim: channel count ``C`` of the input map.
        hidden_ratio: backbone width is ``dim * hidden_ratio``.
        backbone_layers: number of Linear+GELU layers in the backbone
            (at least one is always built).
        mode: ``'no_gate'`` drops the gated ``tanh(x)`` path; any other value
            keeps it.
    """

    def __init__(self, dim, hidden_ratio=4, backbone_layers=2, mode='default'):
        super().__init__()
        self.dim = dim
        self.mode = mode
        hidden_dim = dim * hidden_ratio

        self.param_net = nn.Sequential(
            nn.LayerNorm(dim),
            *self._make_backbone(dim, hidden_dim, backbone_layers),
            nn.Linear(hidden_dim, 6 * dim),
        )
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.eps = 1e-8        # keeps the decay rate strictly positive
        self.min_gate = 0.01   # gate is kept inside [min_gate, 1 - min_gate]

    @staticmethod
    def _make_backbone(dim, hidden_dim, num_layers):
        """Return the Linear+GELU stack mapping ``dim`` -> ``hidden_dim``.

        At least one layer is always built, so the final
        ``Linear(hidden_dim, 6 * dim)`` head always receives ``hidden_dim``
        features.
        """
        layers = [nn.Linear(dim, hidden_dim), nn.GELU()]
        for _ in range(max(num_layers, 1) - 1):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.GELU()]
        return layers

    @staticmethod
    def lecun_tanh(x):
        """LeCun's scaled tanh, ``1.7159 * tanh(0.666 * x)``."""
        return 1.7159 * torch.tanh(0.666 * x)

    def _process_time(self, t, x):
        """Broadcast ``t`` to one value per row of the flattened input.

        Accepts any of:
            - a Python number or a one-element tensor;
            - a tensor with ``B`` elements (one value per sample);
            - a tensor with ``B*H*W`` elements in ``(b, h, w)`` order,
              e.g. ``[B, 1, H, W]``.

        Tensors keep their autograd graph, so a learned ``t`` receives
        gradients.

        Returns:
            Tensor of shape ``[B*H*W, 1]`` with ``x``'s device and dtype.

        Raises:
            ValueError: if ``t`` has any other number of elements.
        """
        B, _, H, W = x.shape
        rows = B * H * W
        if not torch.is_tensor(t):
            t = torch.tensor(float(t))
        t = t.to(device=x.device, dtype=x.dtype)
        n = t.numel()
        if n == 1:
            return t.reshape(1, 1).expand(rows, 1)
        if n == B:
            return t.reshape(B, 1).repeat_interleave(H * W, dim=0)
        if n == rows:
            return t.reshape(rows, 1)
        raise ValueError(
            f"t must have 1, B={B} or B*H*W={rows} elements, "
            f"got {n} (shape {tuple(t.shape)})"
        )

    def forward(self, t, x):
        """Advance the state ``x`` (``[B, C, H, W]``) by time ``t``.

        ``t`` may be a number, a per-sample tensor or a per-pixel map (see
        ``_process_time``). Returns a ``[B, C, H, W]`` tensor clamped to
        ``[-10, 10]``.
        """
        B, C, H, W = x.shape
        x_flat = x.permute(0, 2, 3, 1).reshape(-1, C)   # [(B*H*W), C]
        t = self._process_time(t, x)                      # [(B*H*W), 1]

        s, b1, r, b2, d1, d2 = self.param_net(x_flat).chunk(6, dim=-1)

        modulated_input = self.lecun_tanh(s * x_flat + b1)
        gate_input = self.sigmoid(r * x_flat + b2).clamp(self.min_gate, 1.0 - self.min_gate)
        if self.mode == 'no_gate':
            input_term = modulated_input
        else:
            input_term = (1 - gate_input) * modulated_input + gate_input * self.tanh(x_flat)

        # softplus(.) + eps > 0, so each decay lies in (0, 1] for t >= 0 and equals 1 at t == 0.
        decay1 = torch.exp(-t * (F.softplus(d1) + self.eps))
        decay2 = torch.exp(-t * (F.softplus(d2) + self.eps))

        h = decay1 * x_flat + (1 - decay2) * input_term
        h = h.reshape(B, H, W, C).permute(0, 3, 1, 2)
        return h.clamp(-10.0, 10.0)


class PatchMerging(nn.Module):
    """Halve the token grid by concatenating each 2x2 neighbourhood and projecting 4C -> 2C.

    Accepts either channel-last maps ``[B, H, W, C]`` or token sequences
    ``[B, H*W, C]`` and returns ``[B, ceil(H/2)*ceil(W/2), 2C]``. Odd grids are
    zero-padded on the bottom/right before merging.

    Args:
        input_resolution: ``(H, W)`` of the incoming grid.
        dim: incoming channel count ``C``.
        norm_layer: normalization factory applied to the ``4C`` concatenation.
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        H, W = self.input_resolution
        if x.dim() == 4:
            B, grid_h, grid_w, C = x.shape
            assert grid_h == H and grid_w == W, "Spatial dimensions mismatch"
        else:
            B, L, C = x.shape
            assert L == H * W, "Input feature has wrong size"
        x = x.view(B, H, W, C)

        pad_h, pad_w = H % 2, W % 2
        if pad_h or pad_w:
            x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))

        # Concatenation order: (even row, even col), (odd, even), (even, odd), (odd, odd).
        quads = [x[:, row::2, col::2, :] for col, row in ((0, 0), (0, 1), (1, 0), (1, 1))]
        merged = torch.cat(quads, -1).view(B, -1, 4 * C)
        return self.reduction(self.norm(merged))


class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    Output is a token sequence ``[B, grid_h * grid_w, embed_dim]``.

    Args:
        img_size: expected input size, an int (square) or an ``(H, W)`` pair.
            Inputs must match it exactly.
        patch_size: patch size, an int (square) or an ``(h, w)`` pair.
        in_chans: input image channels.
        embed_dim: embedding width per patch.
        norm_layer: optional norm factory applied to the embedded tokens.
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3,
                 embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = self._to_pair(img_size)
        patch_size = self._to_pair(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    @staticmethod
    def _to_pair(value):
        """Return ``value`` as an ``(h, w)`` tuple; a scalar is used for both sides."""
        if isinstance(value, (tuple, list)):
            if len(value) != 2:
                raise ValueError(f"expected an int or a pair, got {value!r}")
            return tuple(value)
        return (value, value)

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        x = self.norm(x)
        return x


class WindowPartition(nn.Module):
    """Split a channel-last map into (optionally shifted) square windows and back.

    ``forward`` does three things in order:
        1. cyclically rolls the map by ``-shift_size`` on H and W (when shifting);
        2. zero-pads bottom/right to a multiple of ``window_size``;
        3. cuts the result into ``[num_windows * B, ws, ws, C]``.

    Windows are ordered batch-major, then row-major over the window grid.
    ``reverse`` undoes the three steps in the opposite order: it reassembles the
    padded map, crops the padding, then rolls back.

    Args:
        window_size: side length ``ws`` of each window.
        shift_size: cyclic shift, taken modulo ``window_size``.
    """

    def __init__(self, window_size=7, shift_size=0):
        super().__init__()
        self.window_size = window_size
        self.shift_size = shift_size % window_size

    def _padded(self, n):
        """Round ``n`` up to the next multiple of ``window_size``."""
        return -(-n // self.window_size) * self.window_size

    def forward(self, x):
        """Partition ``x`` (``[B, H, W, C]``) into windows.

        Returns:
            ``(windows, (H, W))``: ``windows`` is ``[B * nW, ws, ws, C]``, and
            ``(H, W)`` is the unpadded size ``reverse`` needs.
        """
        B, H, W, C = x.shape
        ws = self.window_size

        if self.shift_size > 0:
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))

        Hp, Wp = self._padded(H), self._padded(W)
        x = F.pad(x, (0, 0, 0, Wp - W, 0, Hp - H))

        # The view must use the *padded* sizes; reshape also tolerates non-contiguous input.
        x = x.reshape(B, Hp // ws, ws, Wp // ws, ws, C)
        windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, C)
        return windows, (H, W)

    def reverse(self, windows, orig_size):
        """Inverse of ``forward``: ``[B * nW, ws, ws, C]`` -> ``[B, H, W, C]``."""
        H, W = orig_size
        ws = self.window_size
        Hp, Wp = self._padded(H), self._padded(W)
        num_windows = (Hp // ws) * (Wp // ws)
        B = windows.shape[0] // num_windows

        x = windows.reshape(B, Hp // ws, Wp // ws, ws, ws, -1)
        x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)

        # forward rolled first and padded second, so crop first and roll back second.
        x = x[:, :H, :W, :]
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        return x.contiguous()


class SwinLiquidAttention(nn.Module):
    """Windowed multi-head self-attention followed by a liquid (CfC) cell.

    Input ``x`` is channel-last ``[B, H, W, C]``. The steps are:
        1. Split ``x`` into (optionally cyclically shifted) ``ws x ws`` windows.
        2. Attend within each window, with a learned relative-position bias
           (and a region mask for shifted windows).
        3. Merge the windows back and project.
        4. Run the result through a ``CfC_ODEFunc_Cell``. Its time input comes
           from ``time_net``: one value per window, broadcast to every pixel of
           that window.

    Returns ``cell_output + x`` with shape ``[B, H, W, C]``.

    NOTE(review): the ``+ x`` is an inner residual on this module's input.
    Callers that also add their own residual get a double skip; confirm this is
    intended.

    Args:
        dim: channel count ``C`` (must be divisible by ``num_heads``).
        num_heads: attention heads.
        window_size: window side ``ws``.
        shift_size: cyclic shift of the window grid (0 disables shifting).
    """

    def __init__(self, dim, num_heads=8, window_size=7, shift_size=0):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.shift_size = shift_size
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.register_buffer("relative_position_index", self.create_position_index(window_size))
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) ** 2, num_heads))
        trunc_normal_(self.relative_position_bias_table, std=.02)

        # [per-head mean bias (num_heads) | window-mean token (dim)] -> one time value in (0, 1).
        self.time_net = nn.Sequential(
            nn.Linear(num_heads + dim, dim // 4),
            nn.GELU(),
            nn.LayerNorm(dim // 4),
            nn.Linear(dim // 4, 1),
            nn.Sigmoid()
        )
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)
        self.window_partition = WindowPartition(window_size, shift_size)
        self.cfc_ode = CfC_ODEFunc_Cell(dim, hidden_ratio=2)

    def create_position_index(self, window_size):
        """Return a flat ``[ws**2 * ws**2]`` index into the relative-position bias table.

        Entry ``(i, j)`` (flattened) encodes the 2-D offset between tokens ``i``
        and ``j`` of a window, shifted to be non-negative and linearised with
        row stride ``2 * ws - 1``.
        """
        coords = torch.stack(torch.meshgrid(
            torch.arange(window_size),
            torch.arange(window_size), indexing='ij')).flatten(1)
        relative_coords = coords[:, :, None] - coords[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += window_size - 1
        relative_coords[:, :, 1] += window_size - 1
        relative_coords[:, :, 0] *= 2 * window_size - 1
        return relative_coords.sum(-1).view(-1)

    def _relative_position_bias(self):
        """Gather the learned bias table into ``[num_heads, L, L]`` with ``L = ws**2``."""
        L = self.window_size ** 2
        bias = self.relative_position_bias_table[self.relative_position_index]
        return bias.view(L, L, -1).permute(2, 0, 1).contiguous()

    def _shift_mask(self, H, W, device):
        """Build the attention mask for shifted windows.

        Returns a bool ``[nW, L, L]`` mask, True where two tokens of the same
        window must NOT attend to each other, or ``None`` when no shift is applied.

        ``WindowPartition`` rolls the unpadded map by ``-shift`` and then pads
        the bottom/right. Along each axis of length ``n`` the rolled+padded map
        therefore has three regions:
            - contiguous content ``[0, n - s)``;
            - wrapped-around content ``[n - s, n)``;
            - padding ``[n, n_pad)``.
        Tokens may only attend within the same (row-region, col-region) pair.
        """
        shift = self.window_partition.shift_size
        if shift == 0:
            return None
        ws = self.window_size

        def axis_labels(n):
            n_pad = -(-n // ws) * ws
            s = shift % n  # torch.roll wraps shifts larger than the axis
            labels = torch.zeros(n_pad, dtype=torch.long, device=device)
            labels[n - s:n] = 1
            labels[n:] = 2
            return labels

        region = axis_labels(H)[:, None] * 3 + axis_labels(W)[None, :]  # [Hp, Wp]
        Hp, Wp = region.shape
        # Same window ordering as WindowPartition.forward: window-row, window-col, then in-window row/col.
        region = region.view(Hp // ws, ws, Wp // ws, ws).permute(0, 2, 1, 3).reshape(-1, ws * ws)
        return region[:, :, None] != region[:, None, :]

    def forward(self, x):
        """Attend within windows of ``x`` (``[B, H, W, C]``) and return ``[B, H, W, C]``."""
        B, H, W, C = x.shape
        shortcut = x
        ws = self.window_size
        L = ws * ws

        windows, orig_size = self.window_partition(x)          # [N, ws, ws, C]
        N = windows.shape[0]
        windows = windows.reshape(N, L, C)

        bias = self._relative_position_bias()                   # [nH, L, L]

        # One time constant per window, from the bias statistics and the window's mean token.
        time_input = torch.cat([bias.mean(dim=(1, 2)).expand(N, -1), windows.mean(dim=1)], dim=1)
        t_win = self.time_net(time_input)                       # [N, 1]

        q, k, v = (self.qkv(windows)
                   .reshape(N, L, 3, self.num_heads, self.head_dim)
                   .permute(2, 0, 3, 1, 4)
                   .unbind(0))                                  # each [N, nH, L, head_dim]
        attn = (q @ k.transpose(-2, -1)) * self.scale + bias.unsqueeze(0)

        mask = self._shift_mask(H, W, x.device)
        if mask is not None:
            num_windows = mask.shape[0]
            penalty = torch.zeros(mask.shape, dtype=attn.dtype, device=attn.device).masked_fill(mask, -100.0)
            attn = attn.view(N // num_windows, num_windows, self.num_heads, L, L) + penalty[None, :, None]
            attn = attn.view(N, self.num_heads, L, L)
        attn = attn.softmax(dim=-1)

        out = (attn @ v).transpose(1, 2).reshape(N, ws, ws, C)
        out = self.proj(out)
        out = self.window_partition.reverse(out, orig_size)     # [B, H, W, C]

        # The cell works on the spatial map, so lay the per-window times out on the same grid.
        t_map = t_win.view(N, 1, 1, 1).expand(N, ws, ws, 1).contiguous()
        t_map = self.window_partition.reverse(t_map, orig_size).permute(0, 3, 1, 2)  # [B, 1, H, W]

        out = self.cfc_ode(t_map, out.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        return out + shortcut


class Mlp(nn.Module):
    """Feed-forward network: Linear -> LayerNorm -> activation -> dropout -> Linear -> dropout.

    ``hidden_features`` and ``out_features`` both fall back to ``in_features``
    when not given (or falsy).
    """

    def __init__(self, in_features, hidden_features=None,
                 out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        width = hidden_features or in_features
        out_width = out_features or in_features

        self.fc1 = nn.Linear(in_features, width)
        self.act = act_layer()
        self.norm = nn.LayerNorm(width)
        self.fc2 = nn.Linear(width, out_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # Note the LayerNorm sits *before* the activation on the hidden features.
        for stage in (self.fc1, self.norm, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x


class LiquidSwinBlock(nn.Module):
    """Swin-style transformer block on channel-first maps, finished by a CfC cell.

    ``forward`` takes and returns ``[B, C, H, W]``. Attention and MLP run on a
    channel-last view, so their LayerNorms normalize over C. Each branch is
    added back as a drop-path residual, and the sum then goes through
    ``liquid_conv`` with time ``t``.

    NOTE(review): ``SwinLiquidAttention.forward`` already returns its own input
    added back. The attention branch here therefore contributes
    ``x + norm1(x) + attention``; confirm the double skip is intended.
    """

    def __init__(self, dim, num_heads, window_size=7,
                 shift_size=0, mlp_ratio=4., drop_path=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = SwinLiquidAttention(dim, num_heads, window_size, shift_size)
        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = Mlp(dim, hidden_features=int(dim * mlp_ratio))
        self.liquid_conv = CfC_ODEFunc_Cell(dim)

    def forward(self, x, t=1.0):
        to_last = (0, 2, 3, 1)    # [B, C, H, W] -> [B, H, W, C]
        to_first = (0, 3, 1, 2)   # [B, H, W, C] -> [B, C, H, W]

        attn_branch = self.attn(self.norm1(x.permute(*to_last))).permute(*to_first)
        x = x + self.drop_path(attn_branch)

        mlp_branch = self.mlp(self.norm2(x.permute(*to_last))).permute(*to_first)
        x = x + self.drop_path(mlp_branch)

        return self.liquid_conv(t, x)


class LiquidSwinTransformer(nn.Module):
    """Hierarchical Swin-style backbone built from ``LiquidSwinBlock`` stages.

    ``self.layers`` holds, in order: stage-0 blocks, PatchMerging, stage-1
    blocks, PatchMerging, ..., last-stage blocks. ``stage_indices`` records the
    position of each stage's ``nn.Sequential`` so that ``forward_features`` can
    collect its output.

    Args:
        in_chans: image channels.
        embed_dim: width of stage 0; stage ``i`` has ``embed_dim * 2**i`` channels.
        depths: number of blocks per stage.
        num_heads: attention heads per stage.
        window_size: attention window side; odd-indexed blocks in a stage shift
            by ``window_size // 2``.
        img_size: square input side. The input must match it exactly
            (enforced by PatchEmbed).
        num_classes: classifier width, used only when ``return_interm_layers``
            is False.
        return_interm_layers: if True, ``forward`` returns the list of
            per-stage ``[B, C_i, H_i, W_i]`` maps; otherwise it returns class
            logits.
    """

    def __init__(self, in_chans=3, embed_dim=96, depths=(2, 2, 6, 2),
                 num_heads=(3, 6, 12, 24), window_size=7, img_size=224, num_classes=1000,
                 return_interm_layers=True):
        super().__init__()
        self.return_interm_layers = return_interm_layers
        self.num_stages = len(depths)
        self.patch_embed = PatchEmbed(img_size, 4, in_chans, embed_dim)
        self.layers = nn.ModuleList()
        current_res = img_size // 4

        self.stage_indices = []
        for i in range(self.num_stages):
            dim = embed_dim * (2 ** i)
            stage_blocks = [
                LiquidSwinBlock(
                    dim=dim,
                    num_heads=num_heads[i],
                    window_size=window_size,
                    shift_size=0 if d % 2 == 0 else window_size // 2,
                    mlp_ratio=4.,
                    drop_path=0.1 * (d / depths[i]),
                )
                for d in range(depths[i])
            ]
            self.layers.append(nn.Sequential(*stage_blocks))
            self.stage_indices.append(len(self.layers) - 1)

            # Patch merging between stages
            if i != self.num_stages - 1:
                self.layers.append(PatchMerging(
                    input_resolution=(current_res, current_res),
                    dim=dim
                ))
                # PatchMerging pads odd grids, so its output side is ceil(res / 2).
                current_res = (current_res + 1) // 2

        if not self.return_interm_layers:
            self.norm = nn.SyncBatchNorm(embed_dim * (2 ** (self.num_stages - 1)))
            self.head = nn.Linear(embed_dim * (2 ** (self.num_stages - 1)), num_classes)

    def forward_features(self, x):
        """Run all stages and return the list of per-stage ``[B, C, H, W]`` maps."""
        features = []
        x = self.patch_embed(x)                       # [B, L, C]
        B, L, C = x.shape
        H, W = self.patch_embed.grid_size
        x = x.transpose(1, 2).reshape(B, C, H, W)     # [B, C, H, W]

        for i, layer in enumerate(self.layers):
            if isinstance(layer, PatchMerging):
                H, W = x.shape[-2:]
                x = layer(x.permute(0, 2, 3, 1))      # [B, H'*W', C']
                H, W = (H + 1) // 2, (W + 1) // 2
                # Tokens are (h, w)-ordered with channels last; move channels
                # to dim 1 before splitting the token axis into (H', W').
                x = x.transpose(1, 2).reshape(B, -1, H, W)  # [B, C', H', W']
            else:
                x = layer(x)

            if i in self.stage_indices:
                features.append(x)

        return features

    def forward(self, x):
        """Return per-stage features, or class logits when ``return_interm_layers`` is False."""
        features = self.forward_features(x)

        if self.return_interm_layers:
            return features
        else:
            x = self.norm(features[-1])
            x = x.mean(dim=(2, 3))
            x = self.head(x)
            return x


class IQA_Model(nn.Module):
    """Image-quality regressor: ``LiquidSwinTransformer`` backbone + ``SimplifiedIQATail``.

    Args:
        return_interm_layers: forwarded to the backbone. The tail reads
            ``features[-1]`` as a feature map, so this is expected to be True.
            NOTE(review): with False the backbone returns logits and the tail
            would index a row of them; confirm before using that setting.
        pretrained: if True, load backbone weights from ``pretrained_model_path``
            (shape-matched keys only, ``head`` keys skipped).
        pretrained_model_path: checkpoint path for ``pretrained``.
        infer: if True, load the full model strictly from ``infer_model_path``.
        infer_model_path: checkpoint path for ``infer``.
        **kwargs: forwarded to ``LiquidSwinTransformer``.
    """

    # Wrapper prefixes (e.g. DDP's "module.") stripped from the start of checkpoint keys.
    _KEY_PREFIXES = ('model.', 'module.')

    def __init__(self,
                 return_interm_layers: bool = True,
                 pretrained: bool = False,
                 pretrained_model_path: str = "",
                 infer: bool = False,
                 infer_model_path: str = "",
                 **kwargs):
        super().__init__()

        self.backbone = LiquidSwinTransformer(
            return_interm_layers=return_interm_layers,
            **kwargs
        )

        self.iqa_tail = SimplifiedIQATail()

        if pretrained:
            self.load_pretrained_weights(pretrained_model_path)

        if infer:
            self.load_infer_weights(infer_model_path)

    @classmethod
    def _strip_prefixes(cls, state_dict):
        """Return a copy of ``state_dict`` with leading wrapper prefixes removed from each key.

        Prefixes are stripped repeatedly from the *start* of a key, so
        ``"module.model.x"`` becomes ``"x"``. Occurrences inside a key are left
        intact; ``"backbone.submodel.w"`` is not rewritten.
        """
        cleaned = {}
        for key, value in state_dict.items():
            while True:
                prefix = next((p for p in cls._KEY_PREFIXES if key.startswith(p)), None)
                if prefix is None:
                    break
                key = key[len(prefix):]
            cleaned[key] = value
        return cleaned

    def load_pretrained_weights(self, model_path: str):
        """Load matching backbone weights from a checkpoint.

        Only keys present in the backbone with identical shapes are loaded, and
        any key containing ``'head'`` is skipped. Returns the
        ``load_state_dict`` result (missing/unexpected keys) so callers can
        inspect what was loaded.
        """
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location='cpu')
        pretrained_dict = self._strip_prefixes(checkpoint.get('model', checkpoint))

        model_dict = self.backbone.state_dict()
        filtered_dict = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict
               and v.shape == model_dict[k].shape
               and 'head' not in k
        }
        return self.backbone.load_state_dict(filtered_dict, strict=False)

    def load_infer_weights(self, model_path: str):
        """Strictly load full-model weights (backbone + tail) from a checkpoint.

        Raises (via ``load_state_dict(strict=True)``) if keys do not match
        exactly after prefix stripping.
        """
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location='cpu')
        state_dict = self._strip_prefixes(checkpoint.get('model', checkpoint))
        return self.load_state_dict(state_dict, strict=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the predicted quality score(s) for an image batch ``x``."""
        features = self.backbone(x)
        return self.iqa_tail(features)


class IQA_Tail(nn.Module):
    """Query-based quality head over a multi-scale feature pyramid.

    Pipeline:
        1. Each stage map is projected to ``hidden_dim`` (1x1 conv, GELU,
           GroupNorm).
        2. It is resized to ``grid_size x grid_size`` and flattened to tokens.
        3. All tokens are fused by a 2-layer Transformer encoder.
        4. ``num_queries`` learned queries, offset by a projection of the
           pooled last stage, decode against the fused tokens.
        5. The query outputs are mean-pooled and regressed to one score.

    Args:
        embed_dims: channel count of each input stage (one projection per entry).
        hidden_dim: shared token width; must be divisible by ``num_heads``.
        num_queries: number of learned decoder queries.
        num_heads: attention heads in the encoder/decoder layers.
        grid_size: spatial side every stage is resized to before tokenization.
    """

    def __init__(self, embed_dims=(96, 192, 384, 768), hidden_dim=384, num_queries=6,
                 num_heads=6, grid_size=7):
        super().__init__()
        self.grid_size = grid_size

        self.proj_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(dim, hidden_dim, 1),
                nn.GELU(),
                nn.GroupNorm(1, hidden_dim)
            ) for dim in embed_dims
        ])

        self.fusion_encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                dim_feedforward=hidden_dim * 4,
                activation=F.gelu,
                batch_first=True
            ),
            num_layers=2
        )

        self.ref_proj = nn.Linear(embed_dims[-1], hidden_dim)

        self.decoder = nn.TransformerDecoder(
            decoder_layer=nn.TransformerDecoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                dim_feedforward=hidden_dim * 4,
                activation=F.gelu,
                batch_first=True
            ),
            num_layers=2
        )
        self.queries = nn.Parameter(torch.randn(1, num_queries, hidden_dim))
        self.score_head = nn.Linear(hidden_dim, 1)

        nn.init.trunc_normal_(self.queries, std=0.02)

    def forward(self, features):
        """Score a list of ``[B, C_i, H_i, W_i]`` stage maps; returns ``[B, 1]``.

        Raises:
            ValueError: if the number of maps differs from ``len(embed_dims)``.
        """
        if len(features) != len(self.proj_layers):
            raise ValueError(
                f"expected {len(self.proj_layers)} feature maps, got {len(features)}"
            )

        size = (self.grid_size, self.grid_size)
        proj_feats = []
        for feat, proj in zip(features, self.proj_layers):
            feat = proj(feat)
            # Compare both spatial dims (a width-only check misses non-square maps).
            if tuple(feat.shape[-2:]) != size:
                feat = F.interpolate(feat, size=size, mode='bilinear', align_corners=False)
            proj_feats.append(feat.flatten(2).transpose(1, 2))  # [B, grid*grid, hidden_dim]

        fused = self.fusion_encoder(torch.cat(proj_feats, dim=1))

        ref = F.adaptive_avg_pool2d(features[-1], 1).flatten(1)
        ref = self.ref_proj(ref).unsqueeze(1)                  # [B, 1, hidden_dim]

        queries = self.queries.expand(fused.shape[0], -1, -1) + ref
        output = self.decoder(queries, fused)

        return self.score_head(output.mean(dim=1))  # [B, 1]


class IQATail(nn.Module):
    """Fuse four backbone stages plus a shallow image encoding into one quality score.

    Pipeline:
        1. Each stage map is projected to 256 channels by a 1x1 conv; the
           expected stage widths 96/192/384/768 are fixed in ``channel_adjust``.
        2. The raw image is encoded to 256 channels by a stride-4 conv stem.
        3. All five maps are resized to the spatial size of the first stage and
           concatenated (1280 channels).
        4. Two 3x3 convs fuse them, and global pooling plus an MLP give
           ``[B, 1]``.
    """

    def __init__(self):
        super(IQATail, self).__init__()

        self.channel_adjust = nn.ModuleList([
            nn.Conv2d(96, 256, 1),
            nn.Conv2d(192, 256, 1),
            nn.Conv2d(384, 256, 1),
            nn.Conv2d(768, 256, 1)
        ])

        self.image_encoder = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=4, padding=3),  # 224x224 -> 56x56
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 256, 3, padding=1),
            nn.ReLU(inplace=True),
        )

        self.feature_fusion = nn.Sequential(
            nn.Conv2d(256 * 5, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, 3, padding=1),
            nn.ReLU(inplace=True),
        )

        self.regressor = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1)
        )

    def forward(self, features, img):
        """Score an image from its four stage maps and the raw image ``img`` (``[B, 3, H, W]``).

        Returns:
            ``[B, 1]`` quality scores.
        """
        feat1, feat2, feat3, feat4 = features
        stage_maps = [adjust(feat) for adjust, feat
                      in zip(self.channel_adjust, (feat1, feat2, feat3, feat4))]

        # Align everything to the finest stage instead of a hard-coded 56x56 grid.
        # 56 only holds for 224x224 inputs; with that input the result is unchanged.
        target = tuple(stage_maps[0].shape[-2:])
        maps = []
        for m in stage_maps + [self.image_encoder(img)]:
            if tuple(m.shape[-2:]) != target:
                m = F.interpolate(m, size=target, mode='bilinear', align_corners=False)
            maps.append(m)

        fused = self.feature_fusion(torch.cat(maps, dim=1))
        return self.regressor(fused)


class SimplifiedIQATail(nn.Module):
    """Minimal quality head: global-average-pool the last stage map, then a linear layer.

    Args:
        in_dim: channel count of ``features[-1]``. The default of 768 matches
            the last stage of the default backbone.
    """

    def __init__(self, in_dim: int = 768):
        super(SimplifiedIQATail, self).__init__()
        self.regressor = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(in_dim, 1)
        )

    def forward(self, features):
        """Return ``[B, 1]`` scores from the last ``[B, in_dim, H, W]`` map in ``features``."""
        return self.regressor(features[-1])

if __name__ == "__main__":
    # Smoke test: build the default IQA model and push a random batch through it.
    config = dict(
        return_interm_layers=True,
        pretrained=False,
        pretrained_model_path="",
        infer=False,
        infer_model_path="",
        in_chans=3,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        img_size=224,
        num_classes=1000,
    )
    model = IQA_Model(**config)

    batch = torch.randn(2, 3, 224, 224)
    scores = model(batch)
    print("Output shape:", scores.shape)  # [2, 1]